{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert_multitask_learning import train_bert_multitask, train_eval_input_fn, BertMultiTask, params\n",
    "from bert_multitask_learning.predefined_problems import get_weibo_ner_fn, get_weibo_cws_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each problem name to its problem type (both use 'seq_tag').\n",
    "problem_type_dict = dict(weibo_cws='seq_tag', weibo_ner='seq_tag')\n",
    "\n",
    "# Both processing functions are built from the same file pattern.\n",
    "WEIBO_DATA_GLOB = '../data/ner/weiboNER*'\n",
    "\n",
    "processing_fn_dict = {\n",
    "    'weibo_ner': get_weibo_ner_fn(file_path=WEIBO_DATA_GLOB),\n",
    "    'weibo_cws': get_weibo_cws_fn(file_path=WEIBO_DATA_GLOB),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Models\n",
    "If you don't want to control everything, you can just call the `train_bert_multitask` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the params instance to its own name so the imported `params` module\n",
    "# is not overwritten (later cells still call `params.DynamicBatchSizeParams()`),\n",
    "# and pass the instance to the trainer so the setting below is actually used.\n",
    "train_params = params.DynamicBatchSizeParams()\n",
    "train_params.bert_num_hidden_layer = 1\n",
    "train_bert_multitask(\n",
    "    problem='weibo_ner&weibo_cws',\n",
    "    problem_type_dict=problem_type_dict,\n",
    "    processing_fn_dict=processing_fn_dict,\n",
    "    num_gpus=1,\n",
    "    num_epochs=1,\n",
    "    params=train_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want more control over the training process, you can use the lower-level API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.estimator import Estimator\n",
    "from bert_multitask_learning.ckpt_restore_hook import RestoreCheckpointHook\n",
    "\n",
    "problem = 'weibo_ner&weibo_cws'\n",
    "num_gpus = 1\n",
    "\n",
    "# All configuration goes on this instance (not on the `params` module),\n",
    "# because this is the object handed to the model and estimator later on.\n",
    "bert_multitask_params = params.DynamicBatchSizeParams()\n",
    "bert_multitask_params.bert_num_hidden_layer = 1\n",
    "\n",
    "# register every problem together with its type and processing function\n",
    "for new_problem, new_problem_processing_fn in processing_fn_dict.items():\n",
    "    new_problem_type = problem_type_dict[new_problem]\n",
    "    print('Adding new problem {0}, problem type: {1}'.format(\n",
    "        new_problem, new_problem_type))\n",
    "    bert_multitask_params.add_problem(\n",
    "        problem_name=new_problem,\n",
    "        problem_type=new_problem_type,\n",
    "        processing_fn=new_problem_processing_fn)\n",
    "\n",
    "# assign problem to params\n",
    "bert_multitask_params.train_epoch = 1\n",
    "# reuse num_gpus so this stays in sync with the distribution strategy cell\n",
    "bert_multitask_params.assign_problem(problem, gpu=int(num_gpus))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the model function from the low-level params\n",
    "model = BertMultiTask(params=bert_multitask_params)\n",
    "model_fn = model.get_model_fn()\n",
    "\n",
    "# mirrored strategy for distributed training, using 'nccl' all-reduce\n",
    "gpu_count = int(num_gpus)\n",
    "all_reduce_ops = tf.contrib.distribute.AllReduceCrossDeviceOps(\n",
    "    'nccl', num_packs=gpu_count)\n",
    "dist_strategy = tf.contrib.distribute.MirroredStrategy(\n",
    "    num_gpus=gpu_count,\n",
    "    cross_tower_ops=all_reduce_ops)\n",
    "\n",
    "# the same strategy is used for both training and evaluation\n",
    "run_config = tf.estimator.RunConfig(\n",
    "    train_distribute=dist_strategy,\n",
    "    eval_distribute=dist_strategy,\n",
    "    log_step_count_steps=bert_multitask_params.log_every_n_steps)\n",
    "\n",
    "# create estimator\n",
    "estimator = Estimator(\n",
    "    model_fn=model_fn,\n",
    "    model_dir=bert_multitask_params.ckpt_dir,\n",
    "    params=bert_multitask_params,\n",
    "    config=run_config)\n",
    "\n",
    "# pretrained bert restore hook\n",
    "train_hook = RestoreCheckpointHook(bert_multitask_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train\n",
    "def train_input_fn():\n",
    "    \"\"\"Input function for `estimator.train`; delegates to `train_eval_input_fn`.\"\"\"\n",
    "    return train_eval_input_fn(bert_multitask_params)\n",
    "\n",
    "\n",
    "estimator.train(\n",
    "    input_fn=train_input_fn,\n",
    "    max_steps=bert_multitask_params.train_steps,\n",
    "    hooks=[train_hook])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate and Predict\n",
    "\n",
    "For NER and CWS, we need different evaluation logic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert_multitask_learning import eval_bert_multitask, predict_bert_multitask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the weibo_cws problem with the 'acc' scheme\n",
    "eval_bert_multitask(\n",
    "    problem='weibo_cws',\n",
    "    model_dir='models/weibo_cws_weibo_ner_ckpt/',\n",
    "    eval_scheme='acc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the weibo_ner problem with the 'ner' scheme\n",
    "eval_bert_multitask(\n",
    "    problem='weibo_ner',\n",
    "    model_dir='models/weibo_cws_weibo_ner_ckpt/',\n",
    "    eval_scheme='ner')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predict\n",
    "import numpy as np\n",
    "from bert_multitask_learning.utils import get_or_make_label_encoder\n",
    "\n",
    "predict_params = params.DynamicBatchSizeParams()\n",
    "\n",
    "# get prediction generator\n",
    "pred_prob = predict_bert_multitask(\n",
    "    inputs=['中国和美国在打贸易战'],\n",
    "    problem='weibo_cws&weibo_ner',\n",
    "    params=predict_params)\n",
    "\n",
    "# get label encoders, one per problem\n",
    "ner_label_encoder = get_or_make_label_encoder(\n",
    "    params=predict_params, problem='weibo_ner', mode='predict')\n",
    "cws_label_encoder = get_or_make_label_encoder(\n",
    "    params=predict_params, problem='weibo_cws', mode='predict')\n",
    "\n",
    "# arg-max over the last axis, then map the ids back to NER labels\n",
    "for sample_prob in pred_prob:\n",
    "    ner_label_ids = np.argmax(sample_prob['weibo_ner'], axis=-1).tolist()\n",
    "    print(ner_label_encoder.inverse_transform(ner_label_ids))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export Model for Serving\n",
    "\n",
    "You can export the trained model for [serving](https://github.com/JayYip/bert-as-service)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert_multitask_learning import export_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uses `bert_multitask_params` from the lower-level API section, so that\n",
    "# cell must have been run before this one.\n",
    "export_model(bert_multitask_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}